{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
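    "# 作业2:用 KNN 识别 MNIST 中的数字 6\n",
    "\n",
    "在 MNIST 上训练 `KNeighborsClassifier` 二分类器(判断图片是否为数字 6),用交叉验证、混淆矩阵、精确率和召回率进行评估,最后用网格搜索调整超参数。"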
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:17: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
      "  from collections import Mapping, defaultdict\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'DESCR': 'mldata.org dataset: mnist-original', 'COL_NAMES': ['label', 'data'], 'target': array([0., 0., 0., ..., 9., 9., 9.]), 'data': array([[0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       ...,\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0],\n",
      "       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}\n"
     ]
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import cross_val_score,GridSearchCV,cross_val_predict\n",
    "from sklearn.metrics import precision_score,recall_score,confusion_matrix\n",
    "from sklearn.datasets import fetch_mldata\n",
    "# Load the 'MNIST original' dataset, using the current directory as the data cache.\n",
    "# NOTE(review): fetch_mldata is deprecated in later scikit-learn releases in favour of\n",
    "# fetch_openml; presumably this call relies on a locally cached copy -- confirm before re-running.\n",
    "mnist = fetch_mldata(\n",
    "    'MNIST original',\n",
    "    data_home='./',\n",
    ")\n",
    "print(mnist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features and labels come straight from the loaded dataset.\n",
    "X, y = mnist['data'], mnist['target']\n",
    "\n",
    "# The first 60000 rows form the training set; everything after that is the test set.\n",
    "X_train, X_test = X[:60000], X[60000:]\n",
    "y_train, y_test = y[:60000], y[60000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix NumPy's global seed so the shuffle order below is reproducible.\n",
    "np.random.seed(42)\n",
    "\n",
    "# Random ordering of the 60000 training-row indices.\n",
    "nrp = np.random.permutation(60000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the unshuffled X / y directly rather than reassigning X_train / y_train from themselves:\n",
    "# re-running this cell then always yields the same shuffled training set instead of permuting\n",
    "# the already shuffled arrays a second time.\n",
    "X_train, y_train = X[:60000][nrp], y[:60000][nrp]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean target for the 6-vs-rest task: True wherever the training label equals 6.\n",
    "y_train_6 = y_train == 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# KNN classifier with default hyperparameters, fitted on the 6-vs-rest training labels.\n",
    "kc_fier = KNeighborsClassifier()\n",
    "kc_fier.fit(X=X_train, y=y_train_6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([False, False, False, False, False, False, False, False, False,\n",
       "       False])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick sanity check: despite the name, these ten rows (indices 10-19) come from the\n",
    "# training set the classifier was fitted on, not from the held-out test set.\n",
    "test_digits = X_train[10:20]\n",
    "kc_fier.predict(test_digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.99650017, 0.997     , 0.99664983])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3-fold cross-validated scores of the classifier on the training data (default scorer).\n",
    "train_score = cross_val_score(estimator=kc_fier, X=X_train, y=y_train_6, cv=3)\n",
    "train_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-validated (3-fold) prediction for every training sample.\n",
    "y_train_predict = cross_val_predict(estimator=kc_fier, X=X_train, y=y_train_6, cv=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[53958,   124],\n",
       "       [   73,  5845]], dtype=int64)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confusion matrix of the cross-validated training predictions.\n",
    "confusion_matrix(y_true=y_train_6, y_pred=y_train_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9792260010051935"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Precision of the cross-validated training predictions.\n",
    "precision_score(y_true=y_train_6, y_pred=y_train_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.987664751605272"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recall of the cross-validated training predictions.\n",
    "recall_score(y_true=y_train_6, y_pred=y_train_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions of the fitted default classifier on the held-out test set.\n",
    "y_test_pred = kc_fier.predict(X=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean test target: True wherever the test label equals 6.\n",
    "y_test_6 = y_test == 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9025,   17],\n",
       "       [  13,  945]], dtype=int64)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confusion matrix of the default classifier on the test set.\n",
    "confusion_matrix(y_true=y_test_6, y_pred=y_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9823284823284824"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Precision of the default classifier on the test set.\n",
    "precision_score(y_true=y_test_6, y_pred=y_test_pred)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9864300626304802"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recall of the default classifier on the test set.\n",
    "recall_score(y_true=y_test_6, y_pred=y_test_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 网格搜索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 5 folds for each of 6 candidates, totalling 30 fits\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n",
      "[CV]  n_neighbors=2, weights=uniform, score=0.9966669444212982, total=17.3min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed: 67.7min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  n_neighbors=2, weights=uniform, score=0.9973335555370386, total=11.6min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed: 123.3min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV] .... n_neighbors=2, weights=uniform, score=0.99725, total=11.6min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed: 179.1min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  n_neighbors=2, weights=uniform, score=0.9958329860821735, total=12.0min\n",
      "[CV] n_neighbors=2, weights=uniform ..................................\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed: 234.1min remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  n_neighbors=2, weights=uniform, score=0.996999749979165, total=11.3min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV]  n_neighbors=2, weights=distance, score=0.997083576368636, total=11.3min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV]  n_neighbors=2, weights=distance, score=0.9970002499791684, total=11.3min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV]  n_neighbors=2, weights=distance, score=0.9975833333333334, total=12.9min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV]  n_neighbors=2, weights=distance, score=0.9959163263605301, total=12.7min\n",
      "[CV] n_neighbors=2, weights=distance .................................\n",
      "[CV]  n_neighbors=2, weights=distance, score=0.9968330694224519, total=12.5min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV]  n_neighbors=4, weights=uniform, score=0.9963336388634281, total=12.3min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV]  n_neighbors=4, weights=uniform, score=0.9975002083159736, total=13.7min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV]  n_neighbors=4, weights=uniform, score=0.9973333333333333, total=15.9min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV]  n_neighbors=4, weights=uniform, score=0.9970830902575215, total=16.3min\n",
      "[CV] n_neighbors=4, weights=uniform ..................................\n",
      "[CV]  n_neighbors=4, weights=uniform, score=0.9967497291440953, total=16.1min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV]  n_neighbors=4, weights=distance, score=0.9959170069160903, total=15.9min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV]  n_neighbors=4, weights=distance, score=0.9974168819265061, total=17.6min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV]  n_neighbors=4, weights=distance, score=0.9976666666666667, total=18.4min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV]  n_neighbors=4, weights=distance, score=0.996999749979165, total=18.1min\n",
      "[CV] n_neighbors=4, weights=distance .................................\n",
      "[CV]  n_neighbors=4, weights=distance, score=0.997166430535878, total=17.6min\n",
      "[CV] n_neighbors=6, weights=uniform ..................................\n",
      "[CV]  n_neighbors=6, weights=uniform, score=0.9960836596950254, total=18.9min\n",
      "[CV] n_neighbors=6, weights=uniform ..................................\n",
      "[CV]  n_neighbors=6, weights=uniform, score=0.9975835347054413, total=12.9min\n",
      "[CV] n_neighbors=6, weights=uniform ..................................\n",
      "[CV]  n_neighbors=6, weights=uniform, score=0.9975833333333334, total=13.3min\n",
      "[CV] n_neighbors=6, weights=uniform ..................................\n",
      "[CV]  n_neighbors=6, weights=uniform, score=0.9967497291440953, total=17.0min\n",
      "[CV] n_neighbors=6, weights=uniform ..................................\n",
      "[CV]  n_neighbors=6, weights=uniform, score=0.9968330694224519, total=13.7min\n",
      "[CV] n_neighbors=6, weights=distance .................................\n",
      "[CV]  n_neighbors=6, weights=distance, score=0.9958336805266228, total=16.1min\n",
      "[CV] n_neighbors=6, weights=distance .................................\n",
      "[CV]  n_neighbors=6, weights=distance, score=0.9973335555370386, total=17.9min\n",
      "[CV] n_neighbors=6, weights=distance .................................\n",
      "[CV]  n_neighbors=6, weights=distance, score=0.9976666666666667, total=13.9min\n",
      "[CV] n_neighbors=6, weights=distance .................................\n",
      "[CV]  n_neighbors=6, weights=distance, score=0.996999749979165, total=17.5min\n",
      "[CV] n_neighbors=6, weights=distance .................................\n",
      "[CV]  n_neighbors=6, weights=distance, score=0.9968330694224519, total=14.6min\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done  30 out of  30 | elapsed: 2142.0min finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5, error_score='raise',\n",
       "       estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
       "           weights='uniform'),\n",
       "       fit_params=None, iid=True, n_jobs=1,\n",
       "       param_grid=[{'weights': ['uniform', 'distance'], 'n_neighbors': [2, 4, 6]}],\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring=None, verbose=5)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameter grid: 2 weighting schemes x 3 neighbourhood sizes = 6 candidates.\n",
    "param_grid = [{'weights': ['uniform', 'distance'], 'n_neighbors': [2, 4, 6]}]\n",
    "\n",
    "# Pass a fresh estimator directly instead of rebinding kc_fier to an unfitted one, so the\n",
    "# cells that call kc_fier.predict keep working when they are re-run after this cell.\n",
    "# n_jobs=-1 spreads the 30 fits (5 folds x 6 candidates) over all CPU cores instead of\n",
    "# running them one after another; the selected model is unaffected.\n",
    "grid_search = GridSearchCV(\n",
    "    KNeighborsClassifier(),\n",
    "    param_grid,\n",
    "    cv=5,\n",
    "    verbose=5,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "grid_search.fit(X_train, y_train_6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions from the fitted grid search.\n",
    "y_grid_test_pred = grid_search.predict(X=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([False, False, False, ..., False, False, False])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the grid-search predictions for the test set.\n",
    "y_grid_test_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9025,   17],\n",
       "       [  12,  946]], dtype=int64)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confusion matrix of the grid-search predictions on the test set.\n",
    "confusion_matrix(y_true=y_test_6, y_pred=y_grid_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9823468328141225"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Precision of the grid-search predictions on the test set.\n",
    "precision_score(y_true=y_test_6, y_pred=y_grid_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9874739039665971"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recall of the grid-search predictions on the test set.\n",
    "recall_score(y_true=y_test_6, y_pred=y_grid_test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
